// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <absl/types/variant.h>

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/schedule.h"

namespace cinn {
namespace lang {

//! Body of the generic Compute overload: receives the index expressions as a
//! single vector (presumably one per domain axis) and returns the expression
//! for the element at that position.
using compute_handler_t = std::function<Expr(const std::vector<Expr> &)>;
//! Value type of the attributes that may be passed to CallExtern.
using attr_t = absl::variant<int, float, bool, std::string>;

//! Compute methods whose body function takes zero to six index expressions as
//! separate arguments, plus a generic overload whose body receives all index
//! expressions as a single vector (see compute_handler_t).
//!
//! Every overload takes the same `domain`, `name` and optional `shape`
//! parameters and returns an ir::Tensor; only the signature of `fn` differs.
//! NOTE(review): `domain` presumably holds the extent of each axis, the arity
//! of `fn` is presumably expected to match `domain.size()`, and `name` is
//! presumably the name given to the resulting tensor — confirm against the
//! implementation.
// @{
// When given, the entries of `shape` are constant integers; it defaults to an
// empty vector.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr()> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});
ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});

ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});

ir::Tensor Compute(const std::vector<Expr> &domain,
                   std::function<Expr(Expr, Expr, Expr, Expr, Expr, Expr)> fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});

// Generic form: the body receives every index expression in one vector.
ir::Tensor Compute(const std::vector<Expr> &domain,
                   compute_handler_t fn,
                   const std::string &name,
                   const std::vector<Expr> &shape = {});
// @}

//! Describes one return value expected from CallLowered (passed in its
//! `return_types` argument).
struct ReturnType {
  Type type;               // presumably the element type of the returned tensor
  std::vector<Expr> dims;  // presumably the shape of the returned tensor
  std::string name;        // presumably the name given to the returned tensor
};

/**
 * \brief Call a lowered function and return one or more tensors as the result.
 *
 * A lowered function is one generated by the lang::Lower method.
 *
 * TODO(Superjomn) Add a registry (symbol table?) so that the return types can
 * be inferred automatically instead of being passed in explicitly.
 *
 * @param func_name The name of the function to call.
 * @param args The read-only arguments (the mutable tensors are returned as the
 * result instead).
 * @param return_types The types of the return values; presumably one entry per
 * returned tensor (see ReturnType).
 * @return One or more tensors as the result.
 */
std::vector<ir::Tensor> CallLowered(
    const std::string &func_name,
    const std::vector<Expr> &args,
    const std::vector<ReturnType> &return_types);

/**
 * \brief Call an external function and get some tensors as the result.
 *
 * There are two kinds of extern functions, distinguished by their return type.
 *
 * 1. Void: there are one or more mutable tensors in the argument list.
 * \code
 * Tensor tuple = Compute(
 *     {M}, []() { return CallExtern("mkl_gemm", {X, W}); }, "tuple");
 * \endcode
 *
 * To support returning multiple values at once, we use a tuple concept: a
 * Tensor whose CallOp is marked with a value_offset (from 0 to
 * num_returns-1).
 *
 * This will generate something like
 *
 * \code
 * for (i) {
 *   mkl_gemm(X[i], gemm_out[i])
 * }
 * \endcode
 *
 * 2. POD value: the call returns an expression directly, and it can be
 * inlined into following computations.
 * \code
 * Tensor tanh_out = Compute(
 *     {M}, [](Var i) { return CallExtern("tanh", {X(i)}); }, "tanh_out");
 * \endcode
 *
 * @param func_name The name of the function to call.
 * @param args The read-only arguments (there should be only one tensor as the
 * result).
 * @param attrs The read-only attributes passed along with the call; empty by
 * default.
 * @return The expression representing the extern call.
 */
Expr CallExtern(const std::string &func_name,
                const std::vector<Expr> &args,
                const std::map<std::string, attr_t> &attrs = {});

}  // namespace lang
}  // namespace cinn
